community[patch]: Support Streaming in Azure Machine Learning (#18246)

- [x] **PR title**: "community: Support streaming in Azure ML and a few
naming changes"

- [x] **PR message**:
- **Description:** Added streaming support to `azureml_endpoint`. Renamed
`AzureMLEndpointApiType.realtime` to `AzureMLEndpointApiType.dedicated`,
added the new classes `CustomOpenAIChatContentFormatter` and
`CustomOpenAIContentFormatter`, and updated `LlamaChatContentFormatter` and
`LlamaContentFormatter` to emit a deprecation warning when instantiated.
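
For illustration, a minimal before/after sketch of the renames (class and
enum names are taken from the diff below; everything else is placeholder):

```python
# Before: deprecated names (still importable, but they now warn when used).
from langchain_community.chat_models.azureml_endpoint import LlamaChatContentFormatter

# After: the renamed equivalents introduced by this PR.
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLEndpointApiType,
    CustomOpenAIChatContentFormatter,
)

formatter = CustomOpenAIChatContentFormatter()
api_type = AzureMLEndpointApiType.dedicated  # formerly AzureMLEndpointApiType.realtime
```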

---------

Co-authored-by: Sachin Paryani <saparan@microsoft.com>
Co-authored-by: Bagatur <baskaryan@gmail.com>
Sachin Paryani 2024-03-28 16:38:20 -07:00 committed by GitHub
parent ecb11a4a32
commit 25c9f3d1d1
6 changed files with 285 additions and 76 deletions

View File

@ -40,7 +40,7 @@
"You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n", "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
"\n", "\n",
"* `endpoint_url`: The REST endpoint url provided by the endpoint.\n", "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n", "* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
"* `endpoint_api_key`: The API key provided by the endpoint" "* `endpoint_api_key`: The API key provided by the endpoint"
] ]
}, },
@ -52,9 +52,9 @@
"\n", "\n",
"The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n", "The `content_formatter` parameter is a handler class for transforming the request and response of an AzureML endpoint to match with required schema. Since there are a wide range of models in the model catalog, each of which may process data differently from one another, a `ContentFormatterBase` class is provided to allow users to transform data to their liking. The following content formatters are provided:\n",
"\n", "\n",
"* `LLamaChatContentFormatter`: Formats request and response data for LLaMa2-chat\n", "* `CustomOpenAIChatContentFormatter`: Formats request and response data for models like LLaMa2-chat that follow the OpenAI API spec for request and response.\n",
"\n", "\n",
"*Note: `langchain.chat_models.azureml_endpoint.LLamaContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.LLamaChatContentFormatter`.*\n", "*Note: `langchain.chat_models.azureml_endpoint.LlamaChatContentFormatter` is being deprecated and replaced with `langchain.chat_models.azureml_endpoint.CustomOpenAIChatContentFormatter`.*\n",
"\n", "\n",
"You can implement custom content formatters specific for your model deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`." "You can implement custom content formatters specific for your model deriving from the class `langchain_community.llms.azureml_endpoint.ContentFormatterBase`."
] ]
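
To illustrate that last point, here is a minimal sketch of a custom formatter
derived from `ContentFormatterBase`; the JSON field names `inputs` and
`generated_text` are assumptions about a hypothetical model's schema, not
part of this PR:

```python
import json
from typing import Dict, List

from langchain_community.llms.azureml_endpoint import (
    AzureMLEndpointApiType,
    ContentFormatterBase,
)
from langchain_core.outputs import Generation


class MyModelContentFormatter(ContentFormatterBase):
    """Hypothetical formatter for a model expecting {"inputs": [...]} payloads."""

    @property
    def supported_api_types(self) -> List[AzureMLEndpointApiType]:
        return [AzureMLEndpointApiType.dedicated]

    def format_request_payload(  # type: ignore[override]
        self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
    ) -> bytes:
        prompt = ContentFormatterBase.escape_special_characters(prompt)
        return str.encode(json.dumps({"inputs": [prompt], "parameters": model_kwargs}))

    def format_response_payload(  # type: ignore[override]
        self, output: bytes, api_type: AzureMLEndpointApiType
    ) -> Generation:
        # Assumes the endpoint returns [{"generated_text": "..."}].
        return Generation(text=json.loads(output)[0]["generated_text"])
```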
@ -65,20 +65,7 @@
"source": [ "source": [
"## Examples\n", "## Examples\n",
"\n", "\n",
"The following section cotain examples about how to use this class:" "The following section contains examples about how to use this class:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.azureml_endpoint import (\n",
" AzureMLEndpointApiType,\n",
" LlamaChatContentFormatter,\n",
")\n",
"from langchain_core.messages import HumanMessage"
] ]
}, },
{ {
@ -105,14 +92,17 @@
} }
], ],
"source": [ "source": [
"from langchain_community.chat_models.azureml_endpoint import LlamaContentFormatter\n", "from langchain_community.chat_models.azureml_endpoint import (\n",
" AzureMLEndpointApiType,\n",
" CustomOpenAIChatContentFormatter,\n",
")\n",
"from langchain_core.messages import HumanMessage\n", "from langchain_core.messages import HumanMessage\n",
"\n", "\n",
"chat = AzureMLChatOnlineEndpoint(\n", "chat = AzureMLChatOnlineEndpoint(\n",
" endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n", " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
" endpoint_api_type=AzureMLEndpointApiType.realtime,\n", " endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
" endpoint_api_key=\"my-api-key\",\n", " endpoint_api_key=\"my-api-key\",\n",
" content_formatter=LlamaChatContentFormatter(),\n", " content_formatter=CustomOpenAIChatContentFormatter(),\n",
")\n", ")\n",
"response = chat.invoke(\n", "response = chat.invoke(\n",
" [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n", " [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
@ -137,7 +127,7 @@
" endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n", " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
" endpoint_api_type=AzureMLEndpointApiType.serverless,\n", " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
" endpoint_api_key=\"my-api-key\",\n", " endpoint_api_key=\"my-api-key\",\n",
" content_formatter=LlamaChatContentFormatter,\n", " content_formatter=CustomOpenAIChatContentFormatter,\n",
")\n", ")\n",
"response = chat.invoke(\n", "response = chat.invoke(\n",
" [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n", " [HumanMessage(content=\"Will the Collatz conjecture ever be solved?\")]\n",
@ -149,7 +139,7 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"If you need to pass additional parameters to the model, use `model_kwards` argument:" "If you need to pass additional parameters to the model, use `model_kwargs` argument:"
] ]
}, },
{ {
@ -162,7 +152,7 @@
" endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n", " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions\",\n",
" endpoint_api_type=AzureMLEndpointApiType.serverless,\n", " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
" endpoint_api_key=\"my-api-key\",\n", " endpoint_api_key=\"my-api-key\",\n",
" content_formatter=LlamaChatContentFormatter,\n", " content_formatter=CustomOpenAIChatContentFormatter,\n",
" model_kwargs={\"temperature\": 0.8},\n", " model_kwargs={\"temperature\": 0.8},\n",
")" ")"
] ]
@ -204,7 +194,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.12" "version": "3.9.1"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -29,7 +29,7 @@
"You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n", "You must [deploy a model on Azure ML](https://learn.microsoft.com/en-us/azure/machine-learning/how-to-use-foundation-models?view=azureml-api-2#deploying-foundation-models-to-endpoints-for-inferencing) or [to Azure AI studio](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-open) and obtain the following parameters:\n",
"\n", "\n",
"* `endpoint_url`: The REST endpoint url provided by the endpoint.\n", "* `endpoint_url`: The REST endpoint url provided by the endpoint.\n",
"* `endpoint_api_type`: Use `endpoint_type='realtime'` when deploying models to **Realtime endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n", "* `endpoint_api_type`: Use `endpoint_type='dedicated'` when deploying models to **Dedicated endpoints** (hosted managed infrastructure). Use `endpoint_type='serverless'` when deploying models using the **Pay-as-you-go** offering (model as a service).\n",
"* `endpoint_api_key`: The API key provided by the endpoint.\n", "* `endpoint_api_key`: The API key provided by the endpoint.\n",
"* `deployment_name`: (Optional) The deployment name of the model using the endpoint." "* `deployment_name`: (Optional) The deployment name of the model using the endpoint."
] ]
@ -45,7 +45,7 @@
"* `GPT2ContentFormatter`: Formats request and response data for GPT2\n", "* `GPT2ContentFormatter`: Formats request and response data for GPT2\n",
"* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n", "* `DollyContentFormatter`: Formats request and response data for the Dolly-v2\n",
"* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n", "* `HFContentFormatter`: Formats request and response data for text-generation Hugging Face models\n",
"* `LLamaContentFormatter`: Formats request and response data for LLaMa2\n", "* `CustomOpenAIContentFormatter`: Formats request and response data for models like LLaMa2 that follow OpenAI API compatible scheme.\n",
"\n", "\n",
"*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*" "*Note: `OSSContentFormatter` is being deprecated and replaced with `GPT2ContentFormatter`. The logic is the same but `GPT2ContentFormatter` is a more suitable name. You can still continue to use `OSSContentFormatter` as the changes are backwards compatible.*"
] ]
@ -72,15 +72,15 @@
"source": [ "source": [
"from langchain_community.llms.azureml_endpoint import (\n", "from langchain_community.llms.azureml_endpoint import (\n",
" AzureMLEndpointApiType,\n", " AzureMLEndpointApiType,\n",
" LlamaContentFormatter,\n", " CustomOpenAIContentFormatter,\n",
")\n", ")\n",
"from langchain_core.messages import HumanMessage\n", "from langchain_core.messages import HumanMessage\n",
"\n", "\n",
"llm = AzureMLOnlineEndpoint(\n", "llm = AzureMLOnlineEndpoint(\n",
" endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n", " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/score\",\n",
" endpoint_api_type=AzureMLEndpointApiType.realtime,\n", " endpoint_api_type=AzureMLEndpointApiType.dedicated,\n",
" endpoint_api_key=\"my-api-key\",\n", " endpoint_api_key=\"my-api-key\",\n",
" content_formatter=LlamaContentFormatter(),\n", " content_formatter=CustomOpenAIContentFormatter(),\n",
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n", " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
")\n", ")\n",
"response = llm.invoke(\"Write me a song about sparkling water:\")\n", "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@ -119,7 +119,7 @@
"source": [ "source": [
"from langchain_community.llms.azureml_endpoint import (\n", "from langchain_community.llms.azureml_endpoint import (\n",
" AzureMLEndpointApiType,\n", " AzureMLEndpointApiType,\n",
" LlamaContentFormatter,\n", " CustomOpenAIContentFormatter,\n",
")\n", ")\n",
"from langchain_core.messages import HumanMessage\n", "from langchain_core.messages import HumanMessage\n",
"\n", "\n",
@ -127,7 +127,7 @@
" endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n", " endpoint_url=\"https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/completions\",\n",
" endpoint_api_type=AzureMLEndpointApiType.serverless,\n", " endpoint_api_type=AzureMLEndpointApiType.serverless,\n",
" endpoint_api_key=\"my-api-key\",\n", " endpoint_api_key=\"my-api-key\",\n",
" content_formatter=LlamaContentFormatter(),\n", " content_formatter=CustomOpenAIContentFormatter(),\n",
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n", " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",
")\n", ")\n",
"response = llm.invoke(\"Write me a song about sparkling water:\")\n", "response = llm.invoke(\"Write me a song about sparkling water:\")\n",
@ -181,7 +181,7 @@
"content_formatter = CustomFormatter()\n", "content_formatter = CustomFormatter()\n",
"\n", "\n",
"llm = AzureMLOnlineEndpoint(\n", "llm = AzureMLOnlineEndpoint(\n",
" endpoint_api_type=\"realtime\",\n", " endpoint_api_type=\"dedicated\",\n",
" endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n", " endpoint_api_key=os.getenv(\"BART_ENDPOINT_API_KEY\"),\n",
" endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n", " endpoint_url=os.getenv(\"BART_ENDPOINT_URL\"),\n",
" model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n", " model_kwargs={\"temperature\": 0.8, \"max_new_tokens\": 400},\n",

View File

@ -1,16 +1,37 @@
import json import json
from typing import Any, Dict, List, Optional, cast import warnings
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Mapping,
Optional,
Type,
cast,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import ( from langchain_core.messages import (
AIMessage, AIMessage,
AIMessageChunk,
BaseMessage, BaseMessage,
BaseMessageChunk,
ChatMessage, ChatMessage,
ChatMessageChunk,
FunctionMessageChunk,
HumanMessage, HumanMessage,
HumanMessageChunk,
SystemMessage, SystemMessage,
SystemMessageChunk,
ToolMessageChunk,
) )
from langchain_core.outputs import ChatGeneration, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_community.llms.azureml_endpoint import ( from langchain_community.llms.azureml_endpoint import (
AzureMLBaseEndpoint, AzureMLBaseEndpoint,
@ -25,12 +46,12 @@ class LlamaContentFormatter(ContentFormatterBase):
def __init__(self) -> None: def __init__(self) -> None:
raise TypeError( raise TypeError(
"`LlamaContentFormatter` is deprecated for chat models. Use " "`LlamaContentFormatter` is deprecated for chat models. Use "
"`LlamaChatContentFormatter` instead." "`CustomOpenAIContentFormatter` instead."
) )
class LlamaChatContentFormatter(ContentFormatterBase): class CustomOpenAIChatContentFormatter(ContentFormatterBase):
"""Content formatter for `LLaMA`.""" """Chat Content formatter for models with OpenAI like API scheme."""
SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"] SUPPORTED_ROLES: List[str] = ["user", "assistant", "system"]
@ -55,7 +76,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
} }
elif ( elif (
isinstance(message, ChatMessage) isinstance(message, ChatMessage)
and message.role in LlamaChatContentFormatter.SUPPORTED_ROLES and message.role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES
): ):
return { return {
"role": message.role, "role": message.role,
@ -63,7 +84,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
} }
else: else:
supported = ",".join( supported = ",".join(
[role for role in LlamaChatContentFormatter.SUPPORTED_ROLES] [role for role in CustomOpenAIChatContentFormatter.SUPPORTED_ROLES]
) )
raise ValueError( raise ValueError(
f"""Received unsupported role. f"""Received unsupported role.
@ -72,7 +93,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
@property @property
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless] return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
def format_messages_request_payload( def format_messages_request_payload(
self, self,
@ -82,10 +103,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
) -> bytes: ) -> bytes:
"""Formats the request according to the chosen api""" """Formats the request according to the chosen api"""
chat_messages = [ chat_messages = [
LlamaChatContentFormatter._convert_message_to_dict(message) CustomOpenAIChatContentFormatter._convert_message_to_dict(message)
for message in messages for message in messages
] ]
if api_type == AzureMLEndpointApiType.realtime: if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
request_payload = json.dumps( request_payload = json.dumps(
{ {
"input_data": { "input_data": {
@ -105,10 +129,13 @@ class LlamaChatContentFormatter(ContentFormatterBase):
def format_response_payload( def format_response_payload(
self, self,
output: bytes, output: bytes,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime, api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> ChatGeneration: ) -> ChatGeneration:
"""Formats response""" """Formats response"""
if api_type == AzureMLEndpointApiType.realtime: if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
try: try:
choice = json.loads(output)["output"] choice = json.loads(output)["output"]
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
@ -143,6 +170,20 @@ class LlamaChatContentFormatter(ContentFormatterBase):
raise ValueError(f"`api_type` {api_type} is not supported by this formatter") raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaChatContentFormatter(CustomOpenAIChatContentFormatter):
"""Deprecated: Kept for backwards compatibility
Chat Content formatter for Llama."""
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`LlamaChatContentFormatter` will be deprecated in the future.
Please use `CustomOpenAIChatContentFormatter` instead.
"""
)
class MistralChatContentFormatter(LlamaChatContentFormatter): class MistralChatContentFormatter(LlamaChatContentFormatter):
"""Content formatter for `Mistral`.""" """Content formatter for `Mistral`."""
@ -187,8 +228,8 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
Example: Example:
.. code-block:: python .. code-block:: python
azure_llm = AzureMLOnlineEndpoint( azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score", endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
endpoint_api_type=AzureMLApiType.realtime, endpoint_api_type=AzureMLApiType.serverless,
endpoint_api_key="my-api-key", endpoint_api_key="my-api-key",
content_formatter=chat_content_formatter, content_formatter=chat_content_formatter,
) )
@ -239,3 +280,143 @@ class AzureMLChatOnlineEndpoint(BaseChatModel, AzureMLBaseEndpoint):
response_payload, self.endpoint_api_type response_payload, self.endpoint_api_type
) )
return ChatResult(generations=[generations]) return ChatResult(generations=[generations])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
import openai
params = {}
client_params = {
"api_key": self.endpoint_api_key.get_secret_value(),
"base_url": self.endpoint_url,
"timeout": timeout,
"default_headers": None,
"default_query": None,
"http_client": None,
}
client = openai.OpenAI(**client_params)
message_dicts = [
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
for m in messages
]
params = {"stream": True, "stop": stop, "model": None, **kwargs}
default_chunk_class = AIMessageChunk
for chunk in client.chat.completions.create(messages=message_dicts, **params):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
self.endpoint_url = self.endpoint_url.replace("/chat/completions", "")
timeout = None if "timeout" not in kwargs else kwargs["timeout"]
import openai
params = {}
client_params = {
"api_key": self.endpoint_api_key.get_secret_value(),
"base_url": self.endpoint_url,
"timeout": timeout,
"default_headers": None,
"default_query": None,
"http_client": None,
}
async_client = openai.AsyncOpenAI(**client_params)
message_dicts = [
CustomOpenAIChatContentFormatter._convert_message_to_dict(m)
for m in messages
]
params = {"stream": True, "stop": stop, "model": None, **kwargs}
default_chunk_class = AIMessageChunk
async for chunk in await async_client.chat.completions.create(
messages=message_dicts, **params
):
if not isinstance(chunk, dict):
chunk = chunk.dict()
if len(chunk["choices"]) == 0:
continue
choice = chunk["choices"][0]
chunk = _convert_delta_to_message_chunk(
choice["delta"], default_chunk_class
)
generation_info = {}
if finish_reason := choice.get("finish_reason"):
generation_info["finish_reason"] = finish_reason
logprobs = choice.get("logprobs")
if logprobs:
generation_info["logprobs"] = logprobs
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk
def _convert_delta_to_message_chunk(
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
) -> BaseMessageChunk:
role = cast(str, _dict.get("role"))
content = cast(str, _dict.get("content") or "")
additional_kwargs: Dict = {}
if _dict.get("function_call"):
function_call = dict(_dict["function_call"])
if "name" in function_call and function_call["name"] is None:
function_call["name"] = ""
additional_kwargs["function_call"] = function_call
if _dict.get("tool_calls"):
additional_kwargs["tool_calls"] = _dict["tool_calls"]
if role == "user" or default_class == HumanMessageChunk:
return HumanMessageChunk(content=content)
elif role == "assistant" or default_class == AIMessageChunk:
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
elif role == "system" or default_class == SystemMessageChunk:
return SystemMessageChunk(content=content)
elif role == "function" or default_class == FunctionMessageChunk:
return FunctionMessageChunk(content=content, name=_dict["name"])
elif role == "tool" or default_class == ToolMessageChunk:
return ToolMessageChunk(content=content, tool_call_id=_dict["tool_call_id"])
elif role or default_class == ChatMessageChunk:
return ChatMessageChunk(content=content, role=role)
else:
return default_class(content=content)
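
With `_stream`/`_astream` in place, token streaming flows through the
standard `stream()` interface of `BaseChatModel`. A minimal usage sketch
(endpoint URL and key are placeholders):

```python
from langchain_community.chat_models.azureml_endpoint import (
    AzureMLChatOnlineEndpoint,
    AzureMLEndpointApiType,
    CustomOpenAIChatContentFormatter,
)
from langchain_core.messages import HumanMessage

chat = AzureMLChatOnlineEndpoint(
    endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/v1/chat/completions",
    endpoint_api_type=AzureMLEndpointApiType.serverless,
    endpoint_api_key="my-api-key",
    content_formatter=CustomOpenAIChatContentFormatter(),
)
# Chunks arrive incrementally instead of waiting for the full completion.
for chunk in chat.stream([HumanMessage(content="Count to five.")]):
    print(chunk.content, end="", flush=True)
```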

View File

@ -62,12 +62,14 @@ class AzureMLEndpointClient(object):
class AzureMLEndpointApiType(str, Enum): class AzureMLEndpointApiType(str, Enum):
"""Azure ML endpoints API types. Use `realtime` for models deployed in hosted """Azure ML endpoints API types. Use `dedicated` for models deployed in hosted
infrastructure, or `serverless` for models deployed as a service with a infrastructure (also known as Online Endpoints in Azure Machine Learning),
or `serverless` for models deployed as a service with a
pay-as-you-go billing or PTU. pay-as-you-go billing or PTU.
""" """
realtime = "realtime" dedicated = "dedicated"
realtime = "realtime" #: Deprecated
serverless = "serverless" serverless = "serverless"
@ -141,13 +143,13 @@ class ContentFormatterBase:
deploying models using different hosting methods. Each method may have deploying models using different hosting methods. Each method may have
a different API structure.""" a different API structure."""
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.dedicated]
def format_request_payload( def format_request_payload(
self, self,
prompt: str, prompt: str,
model_kwargs: Dict, model_kwargs: Dict,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime, api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> Any: ) -> Any:
"""Formats the request body according to the input schema of """Formats the request body according to the input schema of
the model. Returns bytes or seekable file like object in the the model. Returns bytes or seekable file like object in the
@ -159,7 +161,7 @@ class ContentFormatterBase:
def format_response_payload( def format_response_payload(
self, self,
output: bytes, output: bytes,
api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime, api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated,
) -> Generation: ) -> Generation:
"""Formats the response body according to the output """Formats the response body according to the output
schema of the model. Returns the data type that is schema of the model. Returns the data type that is
@ -172,7 +174,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
@property @property
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override] def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@ -214,7 +216,7 @@ class HFContentFormatter(ContentFormatterBase):
@property @property
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override] def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@ -240,7 +242,7 @@ class DollyContentFormatter(ContentFormatterBase):
@property @property
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime] return [AzureMLEndpointApiType.dedicated]
def format_request_payload( # type: ignore[override] def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
@ -264,19 +266,22 @@ class DollyContentFormatter(ContentFormatterBase):
return Generation(text=choice) return Generation(text=choice)
class LlamaContentFormatter(ContentFormatterBase): class CustomOpenAIContentFormatter(ContentFormatterBase):
"""Content formatter for LLaMa""" """Content formatter for models that use the OpenAI like API scheme."""
@property @property
def supported_api_types(self) -> List[AzureMLEndpointApiType]: def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless] return [AzureMLEndpointApiType.dedicated, AzureMLEndpointApiType.serverless]
def format_request_payload( # type: ignore[override] def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes: ) -> bytes:
"""Formats the request according to the chosen api""" """Formats the request according to the chosen api"""
prompt = ContentFormatterBase.escape_special_characters(prompt) prompt = ContentFormatterBase.escape_special_characters(prompt)
if api_type == AzureMLEndpointApiType.realtime: if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
request_payload = json.dumps( request_payload = json.dumps(
{ {
"input_data": { "input_data": {
@ -297,7 +302,10 @@ class LlamaContentFormatter(ContentFormatterBase):
self, output: bytes, api_type: AzureMLEndpointApiType self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation: ) -> Generation:
"""Formats response""" """Formats response"""
if api_type == AzureMLEndpointApiType.realtime: if api_type in [
AzureMLEndpointApiType.dedicated,
AzureMLEndpointApiType.realtime,
]:
try: try:
choice = json.loads(output)[0]["0"] choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e: except (KeyError, IndexError, TypeError) as e:
@ -324,6 +332,22 @@ class LlamaContentFormatter(ContentFormatterBase):
raise ValueError(f"`api_type` {api_type} is not supported by this formatter") raise ValueError(f"`api_type` {api_type} is not supported by this formatter")
class LlamaContentFormatter(CustomOpenAIContentFormatter):
"""Deprecated: Kept for backwards compatibility
Content formatter for Llama."""
content_formatter: Any = None
def __init__(self) -> None:
super().__init__()
warnings.warn(
"""`LlamaContentFormatter` will be deprecated in the future.
Please use `CustomOpenAIContentFormatter` instead.
"""
)
class AzureMLBaseEndpoint(BaseModel): class AzureMLBaseEndpoint(BaseModel):
"""Azure ML Online Endpoint models.""" """Azure ML Online Endpoint models."""
@ -331,9 +355,9 @@ class AzureMLBaseEndpoint(BaseModel):
"""URL of pre-existing Endpoint. Should be passed to constructor or specified as """URL of pre-existing Endpoint. Should be passed to constructor or specified as
env var `AZUREML_ENDPOINT_URL`.""" env var `AZUREML_ENDPOINT_URL`."""
endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.realtime endpoint_api_type: AzureMLEndpointApiType = AzureMLEndpointApiType.dedicated
"""Type of the endpoint being consumed. Possible values are `serverless` for """Type of the endpoint being consumed. Possible values are `serverless` for
pay-as-you-go and `realtime` for real-time endpoints. """ pay-as-you-go and `dedicated` for dedicated endpoints. """
endpoint_api_key: SecretStr = convert_to_secret_str("") endpoint_api_key: SecretStr = convert_to_secret_str("")
"""Authentication Key for Endpoint. Should be passed to constructor or specified as """Authentication Key for Endpoint. Should be passed to constructor or specified as
@ -348,6 +372,8 @@ class AzureMLBaseEndpoint(BaseModel):
http_client: Any = None #: :meta private: http_client: Any = None #: :meta private:
max_retries: int = 1
content_formatter: Any = None content_formatter: Any = None
"""The content formatter that provides an input and output """The content formatter that provides an input and output
transform function to handle formats between the LLM and transform function to handle formats between the LLM and
@ -371,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
values, values,
"endpoint_api_type", "endpoint_api_type",
"AZUREML_ENDPOINT_API_TYPE", "AZUREML_ENDPOINT_API_TYPE",
AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.dedicated,
) )
values["timeout"] = get_from_dict_or_env( values["timeout"] = get_from_dict_or_env(
values, values,
@ -404,7 +430,7 @@ class AzureMLBaseEndpoint(BaseModel):
if field_value.endswith("inference.ml.azure.com"): if field_value.endswith("inference.ml.azure.com"):
raise ValueError( raise ValueError(
"`endpoint_url` should contain the full invocation URL including " "`endpoint_url` should contain the full invocation URL including "
"`/score` for `endpoint_api_type='realtime'` or `/v1/completions` " "`/score` for `endpoint_api_type='dedicated'` or `/v1/completions` "
"or `/v1/chat/completions` for `endpoint_api_type='serverless'`" "or `/v1/chat/completions` for `endpoint_api_type='serverless'`"
) )
return field_value return field_value
@ -415,11 +441,15 @@ class AzureMLBaseEndpoint(BaseModel):
) -> AzureMLEndpointApiType: ) -> AzureMLEndpointApiType:
"""Validate that endpoint api type is compatible with the URL format.""" """Validate that endpoint api type is compatible with the URL format."""
endpoint_url = values.get("endpoint_url") endpoint_url = values.get("endpoint_url")
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr] if (
"/score" (
field_value == AzureMLEndpointApiType.dedicated
or field_value == AzureMLEndpointApiType.realtime
)
and not endpoint_url.endswith("/score") # type: ignore[union-attr]
): ):
raise ValueError( raise ValueError(
"Endpoints of type `realtime` should follow the format " "Endpoints of type `dedicated` should follow the format "
"`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`." "`https://<your-endpoint>.<your_region>.inference.ml.azure.com/score`."
" If your endpoint URL ends with `/v1/completions` or" " If your endpoint URL ends with `/v1/completions` or"
"`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead." "`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
@ -461,7 +491,7 @@ class AzureMLOnlineEndpoint(BaseLLM, AzureMLBaseEndpoint):
.. code-block:: python .. code-block:: python
azure_llm = AzureMLOnlineEndpoint( azure_llm = AzureMLOnlineEndpoint(
endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score", endpoint_url="https://<your-endpoint>.<your_region>.inference.ml.azure.com/score",
endpoint_api_type=AzureMLApiType.realtime, endpoint_api_type=AzureMLApiType.dedicated,
endpoint_api_key="my-api-key", endpoint_api_key="my-api-key",
timeout=120, timeout=120,
content_formatter=content_formatter, content_formatter=content_formatter,

View File

@ -5,13 +5,15 @@ from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models.azureml_endpoint import ( from langchain_community.chat_models.azureml_endpoint import (
AzureMLChatOnlineEndpoint, AzureMLChatOnlineEndpoint,
LlamaChatContentFormatter, CustomOpenAIChatContentFormatter,
) )
def test_llama_call() -> None: def test_llama_call() -> None:
"""Test valid call to Open Source Foundation Model.""" """Test valid call to Open Source Foundation Model."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter()) chat = AzureMLChatOnlineEndpoint(
content_formatter=CustomOpenAIChatContentFormatter()
)
response = chat.invoke([HumanMessage(content="Foo")]) response = chat.invoke([HumanMessage(content="Foo")])
assert isinstance(response, BaseMessage) assert isinstance(response, BaseMessage)
assert isinstance(response.content, str) assert isinstance(response.content, str)
@ -19,7 +21,9 @@ def test_llama_call() -> None:
def test_temperature_kwargs() -> None: def test_temperature_kwargs() -> None:
"""Test that timeout kwarg works.""" """Test that timeout kwarg works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter()) chat = AzureMLChatOnlineEndpoint(
content_formatter=CustomOpenAIChatContentFormatter()
)
response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8) response = chat.invoke([HumanMessage(content="FOO")], temperature=0.8)
assert isinstance(response, BaseMessage) assert isinstance(response, BaseMessage)
assert isinstance(response.content, str) assert isinstance(response.content, str)
@ -27,7 +31,9 @@ def test_temperature_kwargs() -> None:
def test_message_history() -> None: def test_message_history() -> None:
"""Test that multiple messages works.""" """Test that multiple messages works."""
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter()) chat = AzureMLChatOnlineEndpoint(
content_formatter=CustomOpenAIChatContentFormatter()
)
response = chat.invoke( response = chat.invoke(
[ [
HumanMessage(content="Hello."), HumanMessage(content="Hello."),
@ -40,7 +46,9 @@ def test_message_history() -> None:
def test_multiple_messages() -> None: def test_multiple_messages() -> None:
chat = AzureMLChatOnlineEndpoint(content_formatter=LlamaChatContentFormatter()) chat = AzureMLChatOnlineEndpoint(
content_formatter=CustomOpenAIChatContentFormatter()
)
message = HumanMessage(content="Hi!") message = HumanMessage(content="Hi!")
response = chat.generate([[message], [message]]) response = chat.generate([[message], [message]])

View File

@ -2,10 +2,10 @@ from langchain_community.llms.azureml_endpoint import (
AzureMLEndpointClient, AzureMLEndpointClient,
AzureMLOnlineEndpoint, AzureMLOnlineEndpoint,
ContentFormatterBase, ContentFormatterBase,
CustomOpenAIContentFormatter,
DollyContentFormatter, DollyContentFormatter,
GPT2ContentFormatter, GPT2ContentFormatter,
HFContentFormatter, HFContentFormatter,
LlamaContentFormatter,
OSSContentFormatter, OSSContentFormatter,
) )
@ -16,6 +16,6 @@ __all__ = [
"OSSContentFormatter", "OSSContentFormatter",
"HFContentFormatter", "HFContentFormatter",
"DollyContentFormatter", "DollyContentFormatter",
"LlamaContentFormatter", "CustomOpenAIContentFormatter",
"AzureMLOnlineEndpoint", "AzureMLOnlineEndpoint",
] ]